{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "fFjof1NgAJwu"
      },
      "outputs": [],
      "source": [
        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
        "\n",
        "# Licensed to the Apache Software Foundation (ASF) under one\n",
        "# or more contributor license agreements. See the NOTICE file\n",
        "# distributed with this work for additional information\n",
        "# regarding copyright ownership. The ASF licenses this file\n",
        "# to you under the Apache License, Version 2.0 (the\n",
        "# \"License\"); you may not use this file except in compliance\n",
        "# with the License. You may obtain a copy of the License at\n",
        "#\n",
        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing,\n",
        "# software distributed under the License is distributed on an\n",
        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
        "# KIND, either express or implied. See the License for the\n",
        "# specific language governing permissions and limitations\n",
        "# under the License"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A8xNRyZMW1yK"
      },
      "source": [
        "# Use Apache Beam and Bigtable to enrich data\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HrCtxslBGK8Z"
      },
      "source": [
        "This notebook shows how to enrich data by using the Apache Beam [enrichment transform](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment/) with [Bigtable](https://cloud.google.com/bigtable/docs/overview). The enrichment transform is an Apache Beam turnkey transform that lets you enrich data by using a key-value lookup. This transform has the following features:\n",
        "\n",
        "- The transform has a built-in Apache Beam handler that interacts with Bigtable to get data to use in the enrichment.\n",
        "- The enrichment transform uses client-side throttling to manage rate limiting the requests. The requests are exponentially backed off with a default retry strategy. You can configure rate limiting to suit your use case."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ltn5zrBiGS9C"
      },
      "source": [
        "This notebook demonstrates the following ecommerce use case:\n",
        "\n",
        "A stream of online transaction from [Pub/Sub](https://cloud.google.com/pubsub/docs/guides) contains the following fields: `sale_id`, `product_id`, `customer_id`, `quantity`, and `price`. Additional customer demographic data is stored in a separate Bigtable cluster. The demographic data is used to enrich the event stream from Pub/Sub. Then, the enriched data is used to predict the next product to recommended to a customer."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gVCtGOKTHMm4"
      },
      "source": [
        "## Before you begin\n",
        "Set up your environment and download dependencies."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YDHPlMjZRuY0"
      },
      "source": [
        "### Install Apache Beam\n",
        "To use the enrichment transform with the built-in Bigtable handler, install the Apache Beam SDK version 2.54.0 or later."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jBakpNZnAhqk"
      },
      "outputs": [],
      "source": [
        "!pip install torch\n",
        "!pip install apache_beam[interactive,gcp]==2.54.0 --quiet"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SiJii48A2Rnb"
      },
      "outputs": [],
      "source": [
        "import datetime\n",
        "import json\n",
        "import math\n",
        "\n",
        "from typing import Any\n",
        "from typing import Dict\n",
        "\n",
        "import torch\n",
        "from google.cloud import pubsub_v1\n",
        "from google.cloud.bigtable import Client\n",
        "from google.cloud.bigtable import column_family\n",
        "\n",
        "import apache_beam as beam\n",
        "import apache_beam.runners.interactive.interactive_beam as ib\n",
        "from apache_beam.ml.inference.base import RunInference\n",
        "from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor\n",
        "from apache_beam.options import pipeline_options\n",
        "from apache_beam.runners.interactive.interactive_runner import InteractiveRunner\n",
        "from apache_beam.transforms.enrichment import Enrichment\n",
        "from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X80jy3FqHjK4"
      },
      "source": [
        "### Authenticate with Google Cloud\n",
        "This notebook reads data from Pub/Sub and Bigtable. To use your Google Cloud account, authenticate this notebook.\n",
        "To prepare for this step, replace `<PROJECT_ID>`, `<INSTANCE_ID>`, and `<TABLE_ID>` with the appropriate values for your setup. These fields are used with Bigtable."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wEXucyi2liij"
      },
      "outputs": [],
      "source": [
        "PROJECT_ID = \"<PROJECT_ID>\" # @param {type:'string'}\n",
        "INSTANCE_ID = \"<INSTANCE_ID>\" # @param {type:'string'}\n",
        "TABLE_ID = \"<TABLE_ID>\" # @param {type:'string'}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Kz9sccyGBqz3"
      },
      "outputs": [],
      "source": [
        "from google.colab import auth\n",
        "auth.authenticate_user(project_id=PROJECT_ID)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RpqZFfFfA_Dt"
      },
      "source": [
        "### Train the model\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8cUpV7mkB_xE"
      },
      "source": [
        "Create sample data by using the format `[product_id, quantity, price, customer_id, customer_location, recommend_product_id]`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TpxDHGObBEsj"
      },
      "outputs": [],
      "source": [
        "data = [\n",
        "    [3, 5, 127, 9, 'China', 7], [1, 6, 167, 5, 'Peru', 4], [5, 4, 91, 2, 'USA', 8], [7, 2, 52, 1, 'India', 4], [1, 8, 118, 3, 'UK', 8], [4, 6, 132, 8, 'Mexico', 2],\n",
        "    [6, 3, 154, 6, 'Brazil', 3], [4, 7, 163, 1, 'India', 7], [5, 2, 80, 4, 'Egypt', 9], [9, 4, 107, 7, 'Bangladesh', 1], [2, 9, 192, 8, 'Mexico', 4], [4, 5, 116, 5, 'Peru', 8],\n",
        "    [8, 1, 195, 1, 'India', 7], [8, 6, 153, 5, 'Peru', 1], [5, 3, 120, 6, 'Brazil', 2], [2, 7, 187, 7, 'Bangladesh', 4], [1, 8, 103, 6, 'Brazil', 8], [2, 9, 181, 1, 'India', 8],\n",
        "    [6, 5, 166, 3, 'UK', 5], [3, 4, 115, 8, 'Mexico', 1], [4, 7, 170, 4, 'Egypt', 2], [9, 3, 141, 7, 'Bangladesh', 3], [9, 3, 157, 1, 'India', 2], [7, 6, 128, 9, 'China', 1],\n",
        "    [1, 8, 102, 3, 'UK', 4], [5, 2, 107, 4, 'Egypt', 6], [6, 5, 164, 8, 'Mexico', 9], [4, 7, 188, 5, 'Peru', 1], [8, 1, 184, 1, 'India', 2], [8, 6, 198, 2, 'USA', 5],\n",
        "    [5, 3, 105, 6, 'Brazil', 7], [2, 7, 162, 7, 'Bangladesh', 7], [1, 8, 133, 9, 'China', 3], [2, 9, 173, 1, 'India', 7], [6, 5, 183, 5, 'Peru', 8], [3, 4, 191, 3, 'UK', 6],\n",
        "    [4, 7, 123, 2, 'USA', 5], [9, 3, 159, 8, 'Mexico', 2], [9, 3, 146, 4, 'Egypt', 8], [7, 6, 194, 1, 'India', 8], [3, 5, 112, 6, 'Brazil', 1], [4, 6, 101, 7, 'Bangladesh', 2],\n",
        "    [8, 1, 192, 4, 'Egypt', 4], [7, 2, 196, 5, 'Peru', 6], [9, 4, 124, 9, 'China', 7], [3, 4, 129, 5, 'Peru', 6], [6, 3, 151, 8, 'Mexico', 9], [5, 7, 114, 7, 'Bangladesh', 4],\n",
        "    [4, 7, 175, 6, 'Brazil', 5], [1, 8, 121, 1, 'India', 2], [4, 6, 187, 2, 'USA', 5], [6, 5, 144, 9, 'China', 9], [9, 4, 103, 5, 'Peru', 3], [5, 3, 84, 3, 'UK', 1],\n",
        "    [3, 5, 193, 2, 'USA', 4], [4, 7, 135, 1, 'India', 1], [7, 6, 148, 8, 'Mexico', 8], [1, 6, 160, 5, 'Peru', 7], [8, 6, 155, 6, 'Brazil', 9], [5, 7, 183, 7, 'Bangladesh', 2],\n",
        "    [2, 9, 125, 4, 'Egypt', 4], [6, 3, 111, 9, 'China', 9], [5, 2, 132, 3, 'UK', 3], [4, 5, 104, 7, 'Bangladesh', 7], [2, 7, 177, 8, 'Mexico', 7]]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bQt1cB4-CSBd"
      },
      "outputs": [],
      "source": [
        "countries_to_id = {'India': 1, 'USA': 2, 'UK': 3, 'Egypt': 4, 'Peru': 5,\n",
        "                   'Brazil': 6, 'Bangladesh': 7, 'Mexico': 8, 'China': 9}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y0Duet4nCdN1"
      },
      "source": [
        "Preprocess the data:\n",
        "\n",
        "1.   Convert the lists to tensors.\n",
        "2.   Separate the features from the expected prediction."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7TT1O7sBCaZN"
      },
      "outputs": [],
      "source": [
        "X = [torch.tensor(item[:4]+[countries_to_id[item[4]]], dtype=torch.float) for item in data]\n",
        "Y = [torch.tensor(item[-1], dtype=torch.float) for item in data]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q6wB_ZsXDjjd"
      },
      "source": [
        "Define a simple model that has five input features and predicts a single value."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nphNfhUnESES"
      },
      "outputs": [],
      "source": [
        "def build_model(n_inputs, n_outputs):\n",
        "  \"\"\"build_model builds and returns a model that takes\n",
        "  `n_inputs` features and predicts `n_outputs` value\"\"\"\n",
        "  return torch.nn.Sequential(\n",
        "      torch.nn.Linear(n_inputs, 8),\n",
        "      torch.nn.ReLU(),\n",
        "      torch.nn.Linear(8, 16),\n",
        "      torch.nn.ReLU(),\n",
        "      torch.nn.Linear(16, n_outputs))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_sBSzDllEmCz"
      },
      "source": [
        "Train the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CaYrplaPDayp"
      },
      "outputs": [],
      "source": [
        "model = build_model(n_inputs=5, n_outputs=1)\n",
        "\n",
        "loss_fn = torch.nn.MSELoss()\n",
        "optimizer = torch.optim.Adam(model.parameters())\n",
        "\n",
        "for epoch in range(1000):\n",
        "  print(f'Epoch {epoch}: ---')\n",
        "  optimizer.zero_grad()\n",
        "  for i in range(len(X)):\n",
        "    pred = model(X[i])\n",
        "    loss = loss_fn(pred, Y[i])\n",
        "    loss.backward()\n",
        "  optimizer.step()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_rJYv8fFFPYb"
      },
      "source": [
        "Save the model to the `STATE_DICT_PATH` variable."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W4t260o9FURP"
      },
      "outputs": [],
      "source": [
        "STATE_DICT_PATH = './model.pth'\n",
        "\n",
        "torch.save(model.state_dict(), STATE_DICT_PATH)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ouMQZ4sC4zuO"
      },
      "source": [
        "### Set up the Bigtable table\n",
        "\n",
        "Create a sample Bigtable table for this notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "E7Y4ipuL5kFD"
      },
      "outputs": [],
      "source": [
        "# Connect to the Bigtable instance. If you don't have admin access, then drop `admin=True`.\n",
        "client = Client(project=PROJECT_ID, admin=True)\n",
        "instance = client.instance(INSTANCE_ID)\n",
        "\n",
        "# Create a column family.\n",
        "column_family_id = 'demograph'\n",
        "max_versions_rule = column_family.MaxVersionsGCRule(2)\n",
        "column_families = {column_family_id: max_versions_rule}\n",
        "\n",
        "# Create a table.\n",
        "table = instance.table(TABLE_ID)\n",
        "\n",
        "# You need admin access to use `.exists()`. If you don't have the admin access, then\n",
        "# comment out the if-else block.\n",
        "if not table.exists():\n",
        "  table.create(column_families=column_families)\n",
        "else:\n",
        "  print(\"Table %s already exists in %s:%s\" % (TABLE_ID, PROJECT_ID, INSTANCE_ID))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eQLkSg3p7WAm"
      },
      "source": [
        "Add rows to the table for the enrichment example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "LI6oYkZ97Vtu",
        "outputId": "c72b28b5-8692-40f5-f8da-85622437d8f7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Inserted row for key: 1\n",
            "Inserted row for key: 2\n",
            "Inserted row for key: 3\n"
          ]
        }
      ],
      "source": [
        "# Define column names for the table.\n",
        "customer_id = 'customer_id'\n",
        "customer_name = 'customer_name'\n",
        "customer_location = 'customer_location'\n",
        "\n",
        "# The following data is sample data to insert into Bigtable.\n",
        "customers = [\n",
        "  {\n",
        "    'customer_id': 1, 'customer_name': 'Sam', 'customer_location': 'India'\n",
        "  },\n",
        "  {\n",
        "    'customer_id': 2, 'customer_name': 'John', 'customer_location': 'USA'\n",
        "  },\n",
        "  {\n",
        "    'customer_id': 3, 'customer_name': 'Travis', 'customer_location': 'UK'\n",
        "  },\n",
        "]\n",
        "\n",
        "for customer in customers:\n",
        "  row_key = str(customer[customer_id]).encode()\n",
        "  row = table.direct_row(row_key)\n",
        "  row.set_cell(\n",
        "    column_family_id,\n",
        "    customer_id.encode(),\n",
        "    str(customer[customer_id]),\n",
        "    timestamp=datetime.datetime.utcnow())\n",
        "  row.set_cell(\n",
        "    column_family_id,\n",
        "    customer_name.encode(),\n",
        "    customer[customer_name],\n",
        "    timestamp=datetime.datetime.utcnow())\n",
        "  row.set_cell(\n",
        "    column_family_id,\n",
        "    customer_location.encode(),\n",
        "    customer[customer_location],\n",
        "    timestamp=datetime.datetime.utcnow())\n",
        "  row.commit()\n",
        "  print('Inserted row for key: %s' % customer[customer_id])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pHODouJDwc60"
      },
      "source": [
        "### Publish messages to Pub/Sub\n",
        "\n",
        "Use the Pub/Sub Python client to publish messages.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QKCuwDioxw-f"
      },
      "outputs": [],
      "source": [
        "# Replace <TOPIC_NAME> with the name of your Pub/Sub topic.\n",
        "TOPIC = \"<TOPIC_NAME>\" # @param {type:'string'}\n",
        "\n",
        "# Replace <SUBSCRIPTION_PATH> with the subscription for your topic.\n",
        "SUBSCRIPTION = \"<SUBSCRIPTION_PATH>\" # @param {type:'string'}\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MaCJwaPexPKZ"
      },
      "outputs": [],
      "source": [
        "messages = [\n",
        "    {'sale_id': i, 'customer_id': i, 'product_id': i, 'quantity': i, 'price': i*100}\n",
        "    for i in range(1,4)\n",
        "  ]\n",
        "\n",
        "publisher = pubsub_v1.PublisherClient()\n",
        "topic_name = publisher.topic_path(PROJECT_ID, TOPIC)\n",
        "\n",
        "for message in messages:\n",
        "  data = json.dumps(message).encode('utf-8')\n",
        "  publish_future = publisher.publish(topic_name, data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zPSFEMm02omi"
      },
      "source": [
        "## Use the Bigtable enrichment handler\n",
        "\n",
        "The [`BigTableEnrichmentHandler`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigtable.html#apache_beam.transforms.enrichment_handlers.bigtable.BigTableEnrichmentHandler) is a built-in handler included in the Apache Beam SDK versions 2.54.0 and later."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K41xhvmA5yQk"
      },
      "source": [
        "Configure the `BigTableEnrichmentHandler` handler with the following required parameters:\n",
        "\n",
        "* `project_id`: the Google Cloud project ID for the Bigtable instance\n",
        "* `instance_id`: the instance name of the Bigtable cluster\n",
        "* `table_id`: the table ID of table containing relevant data\n",
        "* `row_key`: The field name from the input row that contains the row key to use when querying Bigtable.\n",
        "\n",
        "Optionally, you can use parameters to further configure the `BigTableEnrichmentHandler` handler. For more information about the available parameters, see the [enrichment handler module documentation](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigtable.html#module-apache_beam.transforms.enrichment_handlers.bigtable)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yFMcaf8i7TbI"
      },
      "source": [
        "**Note:** When exceptions occur, by default, the logging severity is set to warning ([`ExceptionLevel.WARN`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.utils.html#apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel.WARN)).  To configure the severity to raise exceptions, set `exception_level` to [`ExceptionLevel.RAISE`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.utils.html#apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel.RAISE). To ignore exceptions, set `exception_level` to [`ExceptionLevel.QUIET`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.utils.html#apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel.QUIET).\n",
        "\n",
        "The following example demonstrates how to set the exception level in the `BigTableEnrichmentHandler` handler:\n",
        "\n",
        "```\n",
        "bigtable_handler = BigTableEnrichmentHandler(project_id=PROJECT_ID,\n",
        "                                             instance_id=INSTANCE_ID,\n",
        "                                             table_id=TABLE_ID,\n",
        "                                             row_key=row_key,\n",
        "                                             exception_level=ExceptionLevel.RAISE)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UEpjy_IsW4P4"
      },
      "source": [
        "The `row_key` parameter represents the field in input schema (`beam.Row`) that contains the row key for a row in the table.\n",
        "\n",
        "Starting with Apache Beam version 2.54.0, you can perform either of the following tasks when a table uses composite row keys:\n",
        "* Modify the input schema to contain the row key in the format required by Bigtable.\n",
        "* Use a custom enrichment handler. For more information, see the [example handler with composite row key support](https://www.toptal.com/developers/paste-gd/BYFGUL08#)."
      ]
    },
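    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The first option can be sketched in plain Python. This is a minimal illustration, not the handler's own behavior. The helper name `add_composite_row_key`, the `region` field, the `#` separator, and the `row_key` field name are all assumptions; use the key layout that your table actually uses.\n",
        "\n",
        "```python\n",
        "from typing import Any, Dict, Sequence\n",
        "\n",
        "\n",
        "def add_composite_row_key(\n",
        "    record: Dict[str, Any],\n",
        "    key_fields: Sequence[str],\n",
        "    separator: str = '#',\n",
        "    key_name: str = 'row_key') -> Dict[str, Any]:\n",
        "  # Join the selected field values into one string and store it in a new field.\n",
        "  # The input record is left unchanged.\n",
        "  composite = separator.join(str(record[field]) for field in key_fields)\n",
        "  return {**record, key_name: composite}\n",
        "\n",
        "\n",
        "# Hypothetical input record with a hypothetical `region` field.\n",
        "sale = {'sale_id': 1, 'customer_id': 7, 'region': 'emea'}\n",
        "keyed = add_composite_row_key(sale, ['region', 'customer_id'])\n",
        "print(keyed['row_key'])\n",
        "```\n",
        "\n",
        "In a pipeline, a step like this could run before the enrichment transform, with the name of the new field passed as the handler's `row_key` parameter."
      ]
    },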
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3dB26jhI45gd"
      },
      "outputs": [],
      "source": [
        "row_key = 'customer_id'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cr1j_DHK4gA4"
      },
      "outputs": [],
      "source": [
        "bigtable_handler = BigTableEnrichmentHandler(project_id=PROJECT_ID,\n",
        "                                             instance_id=INSTANCE_ID,\n",
        "                                             table_id=TABLE_ID,\n",
        "                                             row_key=row_key)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-Lvo8O2V-0Ey"
      },
      "source": [
        "## Use the enrichment transform\n",
        "\n",
        "To use the [enrichment transform](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment.html#apache_beam.transforms.enrichment.Enrichment), the enrichment handler parameter is the only required parameter."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xJTCfSmiV1kv"
      },
      "source": [
        "The following example demonstrates the code needed to add this transform to your pipeline.\n",
        "\n",
        "\n",
        "```\n",
        "with beam.Pipeline() as p:\n",
        "  output = (p\n",
        "            ...\n",
        "            | \"Enrich with BigTable\" >> Enrichment(bigtable_handler)\n",
        "            | \"RunInference\" >> RunInference(model_handler)\n",
        "            ...\n",
        "            )\n",
        "```\n",
        "\n",
        "\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### What is a Cross-Join?\n",
        "A cross-join is a Cartesian product operation where each row from one table is combined with every row from another table. It is useful when we want to create all possible combinations of two datasets.\n",
        "\n",
        "**Example:**\n",
        "Table A:\n",
        "  | A1 | A2 |\n",
        "  |----|----|\n",
        "  |  1 |  X |\n",
        "  |  2 |  Y |\n",
        "\n",
        "Table B:\n",
        "  | B1 | B2 |\n",
        "  |----|----|\n",
        "  | 10 |  P |\n",
        "  | 20 |  Q |\n",
        "\n",
        "**Result of Cross-Join:**\n",
        "  | A1 | A2 | B1 | B2 |\n",
        "  |----|----|----|----|\n",
        "  |  1 |  X | 10 |  P |\n",
        "  |  1 |  X | 20 |  Q |\n",
        "  |  2 |  Y | 10 |  P |\n",
        "  |  2 |  Y | 20 |  Q |\n",
        "\n",
        "Cross-joins can be computationally expensive for large datasets, so use them judiciously.\n",
        "\n",
        "By default, the enrichment transform performs a [`cross_join`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment.html#apache_beam.transforms.enrichment.cross_join). This join returns the enriched row with the following fields: `sale_id`, `customer_id`, `product_id`, `quantity`, `price`, and `customer_location`.\n",
        "\n",
        "To make a prediction when running the ecommerce example, however, the trained model needs the following fields: `product_id`, `quantity`, `price`, `customer_id`, and `customer_location`.\n",
        "\n",
        "Therefore, to get the required fields for the ecommerce example, design a custom join function that takes two dictionaries as input and returns an enriched row that include these fields.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8LnCnEPNIPtg"
      },
      "outputs": [],
      "source": [
        "def custom_join(left: Dict[str, Any], right: Dict[str, Any]):\n",
        "  enriched = {}\n",
        "  enriched['product_id'] = left['product_id']\n",
        "  enriched['quantity'] = left['quantity']\n",
        "  enriched['price'] = left['price']\n",
        "  enriched['customer_id'] = left['customer_id']\n",
        "  enriched['customer_location'] = right['demograph']['customer_location']\n",
        "  return beam.Row(**enriched)"
      ]
    },
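    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see which fields such a join keeps, the following sketch applies the same field selection to one sample message and one lookup response. It uses plain dictionaries instead of `beam.Row` so that it runs without a pipeline. The function name `select_model_fields` is hypothetical, and the shape of `response` is an assumption modeled on the `demograph` column family that the custom join reads; the actual handler response might differ.\n",
        "\n",
        "```python\n",
        "from typing import Any, Dict\n",
        "\n",
        "\n",
        "def select_model_fields(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:\n",
        "  # Keep the request fields in the order that the model expects, and\n",
        "  # add the location from the lookup response.\n",
        "  return {\n",
        "      'product_id': left['product_id'],\n",
        "      'quantity': left['quantity'],\n",
        "      'price': left['price'],\n",
        "      'customer_id': left['customer_id'],\n",
        "      'customer_location': right['demograph']['customer_location'],\n",
        "  }\n",
        "\n",
        "\n",
        "message = {'sale_id': 1, 'customer_id': 1, 'product_id': 1, 'quantity': 1, 'price': 100}\n",
        "# Assumed response shape: column family -> column -> value.\n",
        "response = {'demograph': {'customer_id': '1', 'customer_name': 'Sam', 'customer_location': 'India'}}\n",
        "\n",
        "enriched = select_model_fields(message, response)\n",
        "print(list(enriched))\n",
        "```\n",
        "\n",
        "The `sale_id` and `customer_name` fields are dropped because the model does not use them, and the remaining fields keep the order used for the training features."
      ]
    },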
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fe3bIclV1jZ5"
      },
      "source": [
        "To provide a `lambda` function for using a custom join with the enrichment transform, see the following example.\n",
        "\n",
        "```\n",
        "with beam.Pipeline() as p:\n",
        "  output = (p\n",
        "            ...\n",
        "            | \"Enrich with BigTable\" >> Enrichment(bigtable_handler, join_fn=custom_join)\n",
        "            | \"RunInference\" >> RunInference(model_handler)\n",
        "            ...\n",
        "            )\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uilxdknE3ihO"
      },
      "source": [
        "Because the enrichment transform makes API calls to the remote service, use the `timeout` parameter to specify a timeout duration of 10 seconds:\n",
        "\n",
        "```\n",
        "with beam.Pipeline() as p:\n",
        "  output = (p\n",
        "            ...\n",
        "            | \"Enrich with BigTable\" >> Enrichment(bigtable_handler, join_fn=custom_join, timeout=10)\n",
        "            | \"RunInference\" >> RunInference(model_handler)\n",
        "            ...\n",
        "            )\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CX9Cqybu6scV"
      },
      "source": [
        "## Use the `PyTorchModelHandlerTensor` interface to run inference\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zy5Jl7_gLklX"
      },
      "source": [
        "Because the enrichment transform outputs data in the format `beam.Row`, to make it compatible with the [`PyTorchModelHandlerTensor`](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.pytorch_inference.html#apache_beam.ml.inference.pytorch_inference.PytorchModelHandlerTensor) interface,  convert it to `torch.tensor`. Additionally, the enriched field `customer_location` is a `string` type, but the model requires a `float` type. Convert the `customer_location` field to a `float` type."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KBKoB06nL4LF"
      },
      "outputs": [],
      "source": [
        "def convert_row_to_tensor(element: beam.Row):\n",
        "  row_dict = element._asdict()\n",
        "  row_dict['customer_location'] = countries_to_id[row_dict['customer_location']]\n",
        "  return torch.tensor(list(row_dict.values()), dtype=torch.float)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-tGHyB_vL3rJ"
      },
      "source": [
        "Initialize the model handler with the preprocessing function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VqUUEwcU-r2e"
      },
      "outputs": [],
      "source": [
        "model_handler = PytorchModelHandlerTensor(state_dict_path=STATE_DICT_PATH,\n",
        "                                          model_class=build_model,\n",
        "                                          model_params={'n_inputs':5, 'n_outputs':1}\n",
        "                                          ).with_preprocess_fn(convert_row_to_tensor)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vNHI4gVgNec2"
      },
      "source": [
        "Define a `DoFn` to format the output."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rkN-_Yf4Nlwy"
      },
      "outputs": [],
      "source": [
        "class PostProcessor(beam.DoFn):\n",
        "  def process(self, element, *args, **kwargs):\n",
        "    print('Customer %d who bought product %d is recommended to buy product %d' % (element.example[3], element.example[0], math.ceil(element.inference[0])))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0a1zerXycQ0z"
      },
      "source": [
        "## Run the pipeline\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WrwY0_gV_IDK"
      },
      "source": [
        "Configure the pipeline to run in streaming mode."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "t0425sYBsYtB"
      },
      "outputs": [],
      "source": [
        "options = pipeline_options.PipelineOptions()\n",
        "options.view_as(pipeline_options.StandardOptions).streaming = True # Streaming mode is set True"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DBNijQDY_dRe"
      },
      "source": [
        "Pub/Sub sends the data in bytes. Convert the data to `beam.Row` objects by using a `DoFn`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sRw9iL8pKP5O"
      },
      "outputs": [],
      "source": [
        "class DecodeBytes(beam.DoFn):\n",
        "  \"\"\"\n",
        "  The DecodeBytes `DoFn` converts the data read from Pub/Sub to `beam.Row`.\n",
        "  First, decode the encoded string. Convert the output to\n",
        "  a `dict` with `json.loads()`, which is used to create a `beam.Row`.\n",
        "  \"\"\"\n",
        "  def process(self, element, *args, **kwargs):\n",
        "    element_dict = json.loads(element.decode('utf-8'))\n",
        "    yield beam.Row(**element_dict)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xofUJym-_GuB"
      },
      "source": [
        "Use the following code to run the pipeline.\n",
        "\n",
        "**Note:** Because this pipeline is a streaming pipeline, you need to manually stop the cell. If you don't stop the cell, the pipeline continues to run."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "St07XoibcQSb",
        "outputId": "34e0a603-fb77-455c-e40b-d15b672edeb2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Customer 1 who bought product 1 is recommended to buy product 3\n",
            "Customer 2 who bought product 2 is recommended to buy product 5\n",
            "Customer 3 who bought product 3 is recommended to buy product 7\n"
          ]
        }
      ],
      "source": [
        "with beam.Pipeline(options=options) as p:\n",
        "  _ = (p\n",
        "       | \"Read from Pub/Sub\" >> beam.io.ReadFromPubSub(subscription=SUBSCRIPTION)\n",
        "       | \"ConvertToRow\" >> beam.ParDo(DecodeBytes())\n",
        "       | \"Enrichment\" >> Enrichment(bigtable_handler, join_fn=custom_join, timeout=10)\n",
        "       | \"RunInference\" >> RunInference(model_handler)\n",
        "       | \"Format Output\" >> beam.ParDo(PostProcessor())\n",
        "       )\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "RpqZFfFfA_Dt"
      ],
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
